-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[mlir][Transforms] Fix crash in reconcile-unrealized-casts
#158067
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][Transforms] Fix crash in reconcile-unrealized-casts
#158067
Conversation
@llvm/pr-subscribers-mlir-memref Author: Matthias Springer (matthias-springer) ChangesThe Furthermore, the Full diff: https://github.com/llvm/llvm-project/pull/158067.diff 5 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..f6a8e7e60a69c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Set of all cast ops for faster lookups.
+ DenseSet<Operation *> castOpSet;
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+
+ // A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
- // This set is maintained only if `remainingCastOps` is provided.
- DenseSet<Operation *> erasedOps;
// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
- continue;
- }
// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+ return v.getDefiningOp() == castOp;
+ })) {
+ // Ran into a cycle.
+ break;
+ }
+
// Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
+ // matched op. We can directly use those inputs.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
- if (remainingCastOps)
- for (UnrealizedConversionCastOp op : castOps)
- if (!erasedOps.contains(op.getOperation()))
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
+ // used by an op that is not a cast op.
+ DenseSet<Operation *> liveOps;
+
+ // Helper function that marks the given op and all ops transitively reachable
+ // input cast ops as alive.
+ auto markOpLive = [&](Operation *op) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (liveOps.insert(op).second) {
+ // Successfully inserted: the op is live. Add its operands to the
+ // worklist to mark them live.
+ for (Value v : op->getOperands())
+ if (castOpSet.contains(v.getDefiningOp()))
+ worklist.push_back(v.getDefiningOp());
+ }
+ }
+ };
+
+ // Find all alive cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ // If any of the users is not a cast op, mark the current op (and its
+ // input ops) as live.
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return !castOpSet.contains(user);
+ }))
+ markOpLive(op);
+ }
+
+ // Erase all dead cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (liveOps.contains(op)) {
+ // Op is alive and was not erased. Add it to the remaining cast ops.
+ if (remainingCastOps)
remainingCastOps->push_back(op);
+ continue;
+ }
+
+ // Op is dead. Erase it.
+ op->dropAllUses();
+ op->erase();
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.user"(%2) : (i32) -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[c0:.*]] = arith.constant
+// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %cst = arith.constant 0 : i32
+ %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+ "test.user"(%0) : (i32) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
|
@llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesThe Furthermore, the Full diff: https://github.com/llvm/llvm-project/pull/158067.diff 5 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..f6a8e7e60a69c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Set of all cast ops for faster lookups.
+ DenseSet<Operation *> castOpSet;
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+
+ // A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
- // This set is maintained only if `remainingCastOps` is provided.
- DenseSet<Operation *> erasedOps;
// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
- continue;
- }
// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+ return v.getDefiningOp() == castOp;
+ })) {
+ // Ran into a cycle.
+ break;
+ }
+
// Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
+ // matched op. We can directly use those inputs.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
- if (remainingCastOps)
- for (UnrealizedConversionCastOp op : castOps)
- if (!erasedOps.contains(op.getOperation()))
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
+ // used by an op that is not a cast op.
+ DenseSet<Operation *> liveOps;
+
+ // Helper function that marks the given op and all ops transitively reachable
+ // input cast ops as alive.
+ auto markOpLive = [&](Operation *op) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (liveOps.insert(op).second) {
+ // Successfully inserted: the op is live. Add its operands to the
+ // worklist to mark them live.
+ for (Value v : op->getOperands())
+ if (castOpSet.contains(v.getDefiningOp()))
+ worklist.push_back(v.getDefiningOp());
+ }
+ }
+ };
+
+ // Find all alive cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ // If any of the users is not a cast op, mark the current op (and its
+ // input ops) as live.
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return !castOpSet.contains(user);
+ }))
+ markOpLive(op);
+ }
+
+ // Erase all dead cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (liveOps.contains(op)) {
+ // Op is alive and was not erased. Add it to the remaining cast ops.
+ if (remainingCastOps)
remainingCastOps->push_back(op);
+ continue;
+ }
+
+ // Op is dead. Erase it.
+ op->dropAllUses();
+ op->erase();
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.user"(%2) : (i32) -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[c0:.*]] = arith.constant
+// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %cst = arith.constant 0 : i32
+ %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+ "test.user"(%0) : (i32) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
|
@llvm/pr-subscribers-mlir-core Author: Matthias Springer (matthias-springer) ChangesThe Furthermore, the Full diff: https://github.com/llvm/llvm-project/pull/158067.diff 5 Files Affected:
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 36ee87b533b3b..f6a8e7e60a69c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -3306,9 +3306,13 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
void mlir::reconcileUnrealizedCasts(
ArrayRef<UnrealizedConversionCastOp> castOps,
SmallVectorImpl<UnrealizedConversionCastOp> *remainingCastOps) {
+ // Set of all cast ops for faster lookups.
+ DenseSet<Operation *> castOpSet;
+ for (UnrealizedConversionCastOp op : castOps)
+ castOpSet.insert(op);
+
+ // A worklist of cast ops to process.
SetVector<UnrealizedConversionCastOp> worklist(llvm::from_range, castOps);
- // This set is maintained only if `remainingCastOps` is provided.
- DenseSet<Operation *> erasedOps;
// Helper function that adds all operands to the worklist that are an
// unrealized_conversion_cast op result.
@@ -3337,39 +3341,73 @@ void mlir::reconcileUnrealizedCasts(
// Process ops in the worklist bottom-to-top.
while (!worklist.empty()) {
UnrealizedConversionCastOp castOp = worklist.pop_back_val();
- if (castOp->use_empty()) {
- // DCE: If the op has no users, erase it. Add the operands to the
- // worklist to find additional DCE opportunities.
- enqueueOperands(castOp);
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
- continue;
- }
// Traverse the chain of input cast ops to see if an op with the same
// input types can be found.
UnrealizedConversionCastOp nextCast = castOp;
while (nextCast) {
if (nextCast.getInputs().getTypes() == castOp.getResultTypes()) {
+ if (llvm::any_of(nextCast.getInputs(), [&](Value v) {
+ return v.getDefiningOp() == castOp;
+ })) {
+ // Ran into a cycle.
+ break;
+ }
+
// Found a cast where the input types match the output types of the
- // matched op. We can directly use those inputs and the matched op can
- // be removed.
+ // matched op. We can directly use those inputs.
enqueueOperands(castOp);
castOp.replaceAllUsesWith(nextCast.getInputs());
- if (remainingCastOps)
- erasedOps.insert(castOp.getOperation());
- castOp->erase();
break;
}
nextCast = getInputCast(nextCast);
}
}
- if (remainingCastOps)
- for (UnrealizedConversionCastOp op : castOps)
- if (!erasedOps.contains(op.getOperation()))
+ // A set of all alive cast ops. I.e., ops whose results are (transitively)
+ // used by an op that is not a cast op.
+ DenseSet<Operation *> liveOps;
+
+ // Helper function that marks the given op and all ops transitively reachable
+ // input cast ops as alive.
+ auto markOpLive = [&](Operation *op) {
+ SmallVector<Operation *> worklist;
+ worklist.push_back(op);
+ while (!worklist.empty()) {
+ Operation *op = worklist.pop_back_val();
+ if (liveOps.insert(op).second) {
+ // Successfully inserted: the op is live. Add its operands to the
+ // worklist to mark them live.
+ for (Value v : op->getOperands())
+ if (castOpSet.contains(v.getDefiningOp()))
+ worklist.push_back(v.getDefiningOp());
+ }
+ }
+ };
+
+ // Find all alive cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ // If any of the users is not a cast op, mark the current op (and its
+ // input ops) as live.
+ if (llvm::any_of(op->getUsers(), [&](Operation *user) {
+ return !castOpSet.contains(user);
+ }))
+ markOpLive(op);
+ }
+
+ // Erase all dead cast ops.
+ for (UnrealizedConversionCastOp op : castOps) {
+ if (liveOps.contains(op)) {
+ // Op is alive and was not erased. Add it to the remaining cast ops.
+ if (remainingCastOps)
remainingCastOps->push_back(op);
+ continue;
+ }
+
+ // Op is dead. Erase it.
+ op->dropAllUses();
+ op->erase();
+ }
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
index 3573114f5e038..ac5ca321c066f 100644
--- a/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
+++ b/mlir/test/Conversion/ReconcileUnrealizedCasts/reconcile-unrealized-casts.mlir
@@ -194,3 +194,53 @@ func.func @emptyCast() -> index {
%0 = builtin.unrealized_conversion_cast to index
return %0 : index
}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[cast0:.*]] = builtin.unrealized_conversion_cast %[[cast2:.*]] : i32 to i64
+// CHECK-NEXT: %[[cast1:.*]] = builtin.unrealized_conversion_cast %[[cast0]] : i64 to i16
+// CHECK-NEXT: %[[cast2]] = builtin.unrealized_conversion_cast %[[cast1]] : i16 to i32
+// CHECK-NEXT: "test.user"(%[[cast2]]) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %2 : i32 to i64
+ %1 = builtin.unrealized_conversion_cast %0 : i64 to i16
+ %2 = builtin.unrealized_conversion_cast %1 : i16 to i32
+ "test.user"(%2) : (i32) -> ()
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %0 = builtin.unrealized_conversion_cast %0 : i32 to i32
+ "test.return"() : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: test.graph_region
+// CHECK-NEXT: %[[c0:.*]] = arith.constant
+// CHECK-NEXT: %[[cast:.*]]:2 = builtin.unrealized_conversion_cast %[[c0]], %[[cast]]#1 : i32, i32 to i32, i32
+// CHECK-NEXT: "test.user"(%[[cast]]#0) : (i32) -> ()
+// CHECK-NEXT: "test.return"() : () -> ()
+test.graph_region {
+ %cst = arith.constant 0 : i32
+ %0, %1 = builtin.unrealized_conversion_cast %cst, %1 : i32, i32 to i32, i32
+ "test.user"(%0) : (i32) -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
index 25a338df8d790..01a826a638606 100644
--- a/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/assume-alignment-runtime-verification.mlir
@@ -1,7 +1,8 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -expand-strided-metadata \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
index 4c6a48d577a6c..1144a7caf36e8 100644
--- a/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/atomic-rmw-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
diff --git a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
index dd000c6904bcb..82e63805cd027 100644
--- a/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
+++ b/mlir/test/Integration/Dialect/MemRef/store-runtime-verification.mlir
@@ -1,6 +1,7 @@
// RUN: mlir-opt %s -generate-runtime-verification \
// RUN: -test-cf-assert \
-// RUN: -convert-to-llvm | \
+// RUN: -convert-to-llvm \
+// RUN: -reconcile-unrealized-casts | \
// RUN: mlir-runner -e main -entry-point-result=void \
// RUN: -shared-libs=%mlir_runner_utils 2>&1 | \
// RUN: FileCheck %s
|
66a0f0b
to
c357eb1
Compare
96cf8f3
to
a1db187
Compare
a1db187
to
f7686bd
Compare
namespace mlir { | ||
|
||
// Predeclaration only. | ||
static void reconcileUnrealizedCasts( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added this predeclaration here to keep the diff small, so that the PR is easier to review. Will move the entire function here in a follow-up NFC PR.
Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
✅ With the latest revision this PR passed the C/C++ code formatter. |
1bb194a
to
6183b1d
Compare
LLVM Buildbot has detected a new failure on builder Full details are available at: https://lab.llvm.org/buildbot/#/builders/138/builds/18986 Here is the relevant piece of the build log for the reference
|
The bot failure looks legit to me, I reverted while you look into it. |
The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops. Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed. Also avoid copying the set of all unresolved materializations in `convertOperations`. This commit is in preparation of turning `RewriterBase::replaceOp` into a non-virtual function. --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
The `reconcile-unrealized-casts` pass used to crash when the input contains circular chains of `unrealized_conversion_cast` ops. Furthermore, the `reconcileUnrealizedCasts` helper functions used to erase ops that were not passed via the `castOps` operand. Such ops are now preserved. That's why some integration tests had to be changed. Also avoid copying the set of all unresolved materializations in `convertOperations`. This commit is in preparation of turning `RewriterBase::replaceOp` into a non-virtual function. This is a re-upload of #158067, which was reverted due to CI failures. Note for LLVM integration: If you are seeing tests that are failing with `error: LLVM Translation failed for operation: builtin.unrealized_conversion_cast`, you may have to add the `-reconcile-unrealized-casts` pass to your pass pipeline. (Or switch to the `-convert-to-llvm` pass instead of combining the various `-convert-*-to-llvm` passes.) --------- Co-authored-by: Mehdi Amini <joker.eph@gmail.com>
…zed-casts`" (#158295) Reverts llvm/llvm-project#158067 Buildbot is broken.
The
reconcile-unrealized-casts
pass used to crash when the input contains circular chains ofunrealized_conversion_cast
ops.Furthermore, the
reconcileUnrealizedCasts
helper functions used to erase ops that were not passed via thecastOps
operand. Such ops are now preserved. That's why some integration tests had to be changed.Also avoid copying the set of all unresolved materializations in
convertOperations
.This commit is in preparation of turning
RewriterBase::replaceOp
into a non-virtual function.